{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3/tutorials/zh_cn/beginner/mindspore_save_load.ipynb)&emsp;\n",
    "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3/tutorials/zh_cn/beginner/mindspore_save_load.py)&emsp;\n",
    "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.3/tutorials/source_zh_cn/beginner/save_load.ipynb)\n",
    "\n",
    "[基本介绍](https://www.mindspore.cn/tutorials/zh-CN/r2.3/beginner/introduction.html) || [快速入门](https://www.mindspore.cn/tutorials/zh-CN/r2.3/beginner/quick_start.html) || [张量 Tensor](https://www.mindspore.cn/tutorials/zh-CN/r2.3/beginner/tensor.html) || [数据集 Dataset](https://www.mindspore.cn/tutorials/zh-CN/r2.3/beginner/dataset.html) || [数据变换 Transforms](https://www.mindspore.cn/tutorials/zh-CN/r2.3/beginner/transforms.html) || [网络构建](https://www.mindspore.cn/tutorials/zh-CN/r2.3/beginner/model.html) || [函数式自动微分](https://www.mindspore.cn/tutorials/zh-CN/r2.3/beginner/autograd.html) || [模型训练](https://www.mindspore.cn/tutorials/zh-CN/r2.3/beginner/train.html) || **保存与加载** || [使用静态图加速](https://www.mindspore.cn/tutorials/zh-CN/r2.3/beginner/accelerate_with_static_graph.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 保存与加载\n",
    "\n",
    "上一章节主要介绍了如何调整超参数，并进行网络模型训练。在训练网络模型的过程中，实际上我们希望保存中间和最后的结果，用于微调（fine-tune）和后续的模型推理与部署，本章节我们将介绍如何保存与加载模型。"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%capture captured_output\n",
    "# 实验环境已经预装了mindspore==2.3.0，如需更换mindspore版本，可更改下面 MINDSPORE_VERSION 变量\n",
    "!pip uninstall mindspore -y\n",
    "!export MINDSPORE_VERSION=2.3.0\n",
    "!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${MINDSPORE_VERSION}/MindSpore/unified/aarch64/mindspore-${MINDSPORE_VERSION}-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.mirrors.ustc.edu.cn/simple"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# 查看当前 mindspore 版本\n",
    "!pip show mindspore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "from mindspore import Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def network():\n",
    "    model = nn.SequentialCell(\n",
    "                nn.Flatten(),\n",
    "                nn.Dense(28*28, 512),\n",
    "                nn.ReLU(),\n",
    "                nn.Dense(512, 512),\n",
    "                nn.ReLU(),\n",
    "                nn.Dense(512, 10))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 保存和加载模型权重\n",
    "\n",
    "保存模型使用`save_checkpoint`接口，传入网络和指定的保存路径："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model = network()\n",
    "mindspore.save_checkpoint(model, \"model.ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "要加载模型权重，需要先创建相同模型的实例，然后使用`load_checkpoint`和`load_param_into_net`方法加载参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = network()\n",
    "param_dict = mindspore.load_checkpoint(\"model.ckpt\")\n",
    "param_not_load, _ = mindspore.load_param_into_net(model, param_dict)\n",
    "print(param_not_load)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "> - `param_not_load`是未被加载的参数列表，为空时代表所有参数均加载成功。\n",
    "> - 当环境中安装有MindX DL（昇腾深度学习组件）6.0及以上版本时，默认启动MindIO加速CheckPoint功能，详情查看[MindIO介绍](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc1/mindio/mindioacp/mindioacp001.html)。MindX DL在[此处](https://www.hiascend.com/developer/download/community/result?module=dl+cann)下载。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存和加载MindIR\n",
    "\n",
    "除Checkpoint外，MindSpore提供了云侧（训练）和端侧（推理）统一的[中间表示（Intermediate Representation，IR）](https://www.mindspore.cn/docs/zh-CN/r2.3/design/all_scenarios.html#中间表示mindir)。可使用`export`接口直接将模型保存为MindIR。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = network()\n",
    "inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))\n",
    "mindspore.export(model, inputs, file_name=\"model\", file_format=\"MINDIR\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> MindIR同时保存了Checkpoint和模型结构，因此需要定义输入Tensor来获取输入shape。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "已有的MindIR模型可以方便地通过`load`接口加载，传入`nn.GraphCell`即可进行推理。\n",
    "\n",
    "> `nn.GraphCell`仅支持图模式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 10)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mindspore.set_context(mode=mindspore.GRAPH_MODE)\n",
    "\n",
    "graph = mindspore.load(\"model.mindir\")\n",
    "model = nn.GraphCell(graph)\n",
    "outputs = model(inputs)\n",
    "print(outputs.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "8c9da313289c39257cb28b126d2dadd33153d4da4d524f730c81a4aaccbd2ca7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
